{
 "cells": [
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:51:54.684824Z",
     "start_time": "2025-09-02T02:51:54.678530Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#Environment wrapper: caps episodes at 200 steps and flattens gym's\n",
    "#5-tuple step API (terminated/truncated) back into a single `done` flag.\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        #disable_env_checker bypasses gym's PassiveEnvChecker, which\n",
    "        #references np.bool8 and raises AttributeError on NumPy >= 2.0\n",
    "        env = gym.make('CartPole-v1',\n",
    "                       render_mode='rgb_array',\n",
    "                       disable_env_checker=True)\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        #gym's reset returns (state, info); expose only the state\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        done = terminated or truncated\n",
    "        self.step_n += 1\n",
    "        #force the episode to end after 200 steps\n",
    "        if self.step_n >= 200:\n",
    "            done = True\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.03824893,  0.01038502,  0.02408452,  0.01904099], dtype=float32)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 29
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:51:55.890621Z",
     "start_time": "2025-09-02T02:51:55.841096Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "#Render the current game frame inline in the notebook.\n",
    "def show():\n",
    "    frame = env.render()\n",
    "    plt.imshow(frame)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ],
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAF7CAYAAAD4/3BBAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjUsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvWftoOwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAJhBJREFUeJzt3Xl01NX9//H3ZCUQkjRANkiQTfaABYTUpWoiEZBKjT0uFKPlwJECR4gixiII9hiKPXUrhj/aij1HRPEIFhQ0BoEiETCSClFS4EsLliwITQLBrPP5nnu/v5lfRkP2zNzJPB89n04+87mZuXPNJC/uNjbLsiwBAAAwiJ+nKwAAAPB9BBQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHEIKAAAwDgEFAAAYByPBpT169fLNddcIz169JDJkyfLoUOHPFkdAADg6wHlrbfekoyMDFm1apV88cUXMm7cOElNTZWysjJPVQkAABjC5qkPC1Q9JpMmTZI//vGP+txut0t8fLwsXrxYnnzySU9UCQAAGCLAE09aW1sr+fn5kpmZ6bzPz89PUlJSJC8v7wfla2pq9OGgwszFixelT58+YrPZ3FZvAADQfqpP5NKlSxIXF6f/7hsXUL799ltpaGiQ6Ohol/vV+fHjx39QPisrS1avXu3GGgIAgK5y9uxZGTBggHkBpa1UT4uar+JQUVEhCQkJ+gWGhYV5tG4AAKB1Kisr9XSO3r17t1jWIwGlb9++4u/vL6WlpS73q/OYmJgflA8ODtbH96lwQkABAMC7tGZ6hkdW8QQFBcmECRMkNzfXZV6JOk9KSvJElQAAgEE8NsSjhmzS09Nl4sSJcv3118uLL74oVVVV8vDDD3uqSgAAwNcDyr333ivnz5+XlStXSklJiYwfP1527dr1g4mzAADA93hsH5SOTrIJDw/Xk2WZgwIAQPf7+81n8QAAAOMQUAAAgHEIKAAAwDgEFAAAYBwCCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHEIKAAAwDgEFAAAYBwCCgAAMA4BBQAAdP+A8swzz4jNZnM5RowY4bxeXV0tCxculD59+khoaKikpaVJaWlpZ1cDAAB4sS7pQRk9erQUFxc7j/379zuvLV26VLZv3y5btmyRvXv3yrlz5+Tuu+/uimoAAAAvFdAlDxoQIDExMT+4v6KiQv785z/Lpk2b5LbbbtP3vfbaazJy5Ej57LPPZMqUKV1RHQAA4GW6pAflxIkTEhcXJ4MHD5bZs2fLmTNn9P35+flSV1cnKSkpzrJq+CchIUHy8vKu+ng1NTVSWVnpcgAAgO6r0wPK5MmTZePGjbJr1y7Jzs6W06dPy0033SSXLl2SkpISCQoKkoiICJfviY6O1teuJisrS8LDw51HfHx8Z1cbAAB05yGeadOmOb9OTEzUgWXgwIHy9ttvS0hISLseMzMzUzIyMpznqgeFkAIAQPfV5cuMVW/JtddeKydPntTzUmpra6W8vNyljFrF09ScFYfg4GAJCwtzOQAAQPfV5QHl8uXLcurUKYmNjZUJEyZIYGCg5ObmOq8XFRXpOSpJSUldXRUAAOCrQzyPP/64zJw5Uw/rqCXEq1atEn9/f7n//vv1/JG5c+fq4ZrIyEjdE7J48WIdTljBAwAAuiygfPPNNzqMXLhwQfr16yc33nijXkKsvlZeeOEF8fPz0xu0qdU5qamp8uqrr3Z2NQAAgBezWZZliZdRk2RVb4zaV4X5KAAAdL+/33wWDwAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADAOAQUAABgHAIK
AAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHEIKAAAwDgEFAAAYBwCCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADA+wPKvn37ZObMmRIXFyc2m022bdvmct2yLFm5cqXExsZKSEiIpKSkyIkTJ1zKXLx4UWbPni1hYWESEREhc+fOlcuXL3f81QAAAN8MKFVVVTJu3DhZv359k9fXrVsnL7/8smzYsEEOHjwovXr1ktTUVKmurnaWUeGksLBQcnJyZMeOHTr0zJ8/v2OvBAAAdBs2S3V5tPebbTbZunWrzJo1S5+rh1I9K4899pg8/vjj+r6KigqJjo6WjRs3yn333Sdff/21jBo1Sg4fPiwTJ07UZXbt2iXTp0+Xb775Rn9/SyorKyU8PFw/tuqFAQAA5mvL3+9OnYNy+vRpKSkp0cM6DqoikydPlry8PH2ubtWwjiOcKKq8n5+f7nFpSk1NjX5RjQ8AANB9dWpAUeFEUT0mjalzxzV1GxUV5XI9ICBAIiMjnWW+LysrSwcdxxEfH9+Z1QYAAIbxilU8mZmZujvIcZw9e9bTVQIAAN4SUGJiYvRtaWmpy/3q3HFN3ZaVlblcr6+v1yt7HGW+Lzg4WI9VNT4AAED31akBZdCgQTpk5ObmOu9T80XU3JKkpCR9rm7Ly8slPz/fWWb37t1it9v1XBUAAICAtn6D2q/k5MmTLhNjCwoK9ByShIQEWbJkifz2t7+VYcOG6cDy9NNP65U5jpU+I0eOlDvuuEPmzZunlyLX1dXJokWL9Aqf1qzgAQAA3V+bA8rnn38ut956q/M8IyND36anp+ulxE888YTeK0Xta6J6Sm688Ua9jLhHjx7O73njjTd0KElOTtard9LS0vTeKQAAAB3eB8VT2AcFAADv47F9UAAAADoDAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHEIKAAAwDgEFAAAYBwCCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgPcHlH379snMmTMlLi5ObDabbNu2zeX6Qw89pO9vfNxxxx0uZS5evCizZ8+WsLAwiYiIkLlz58rly5c7/moAAIBvBpSqqioZN26crF+//qplVCApLi52Hm+++abLdRVOCgsLJScnR3bs2KFDz/z589v3CgAAQLcT0NZvmDZtmj6aExwcLDExMU1e+/rrr2XXrl1y+PBhmThxor7vlVdekenTp8vvf/973TMDAAB8W5fMQdmzZ49ERUXJ8OHDZcGCBXLhwgXntby8PD2s4wgnSkpKivj5+cnBgwebfLyamhqprKx0OQAAQPfV6QFFDe/89a9/ldzcXPnd734ne/fu1T0uDQ0N+npJSYkOL40FBARIZGSkvtaUrKwsCQ8Pdx7x8fGdXW0AAODNQzwtue+++5xfjx07VhITE2XIkCG6VyU5Obldj5mZmSkZGRnOc9WDQkgBAKD76vJlxoMHD5a+ffvKyZMn9bmam1JWVuZSpr6+Xq/sudq8FTWnRa34aXwAAIDuq8sDyjfffKPnoMTGxurzpKQkKS8vl/z8fGeZ3bt3i91ul8mTJ3d1dQAAQHcc4lH7lTh6Q5TTp09LQUGBnkOijtWrV0taWpruDTl16pQ88cQTMnToUElNTdXlR44cqeepzJs3TzZs2CB1dXWyaNEiPTTECh4AAKDYLMuy2tIUai7Jrbfe+oP709PTJTs7W2bNmiVHjhzRvSQqcEydOlWeffZZiY6OdpZVwzkqlGzfvl2v3lGB5uWXX5bQ0NBW1UHNQVGT
ZSsqKhjuAQDAS7Tl73ebA4oJCCgAAHiftvz95rN4AACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMD7PywQADrLv/++SWouX2i2TP9Jd0mvvgluqxMAMxBQAHiE3d4glcVFUv3f4mbLRY9JdludAJiDIR4AHmHZ7SJe91GlANyFgALAM+wNKqZ4uhYADEVAAeARlgooFgEFQNMIKAA8FlAselAAXAUBBYBHWJbqQfF0LQCYioACwCOsBuagALg6AgoAD/agEFAANI2AAsBjy4yZgwLgaggoADy4isfTtQBgKgIKAM8FFBIKgKsgoADwCPZBAdAcAgoAj2AfFADNIaAA8Ag+iwdAcwgoADy3zJiEAuAqCCgAPLZRm8UcFABXQUAB4BG1V8rFaqhvtox/cE/xCwh0W50AmIOAAsAjLhf/UxpqrzRbple/aySoV4Tb6gTAHAQUAMay2fxE1AHA57TpnZ+VlSWTJk2S3r17S1RUlMyaNUuKiopcylRXV8vChQulT58+EhoaKmlpaVJaWupS5syZMzJjxgzp2bOnfpxly5ZJfX3zXb0AfJCf+hVl83QtAJgeUPbu3avDx2effSY5OTlSV1cnU6dOlaqqKmeZpUuXyvbt22XLli26/Llz5+Tuu+92Xm9oaNDhpLa2Vg4cOCCvv/66bNy4UVauXNm5rwyA17PZ/MVmI6AAvshmdWAa/fnz53UPiAoiN998s1RUVEi/fv1k06ZNcs899+gyx48fl5EjR0peXp5MmTJFdu7cKXfeeacOLtHR0brMhg0bZPny5frxgoKCWnzeyspKCQ8P188XFhbW3uoD8KD/2f1nuXDiYLNlfjR4giT85F7moQDdRFv+fndocFc9gRIZGalv8/Pzda9KSkqKs8yIESMkISFBBxRF3Y4dO9YZTpTU1FRd6cLCwiafp6amRl9vfADo/mx+9KAAvqrdAcVut8uSJUvkhhtukDFjxuj7SkpKdA9IRITrv3ZUGFHXHGUahxPHdce1q819UYnLccTHx7e32gC8CJNkAd/V7ne+moty7Ngx2bx5s3S1zMxM3VvjOM6ePdvlzwnAlIBCDwrgiwLa802LFi2SHTt2yL59+2TAgAHO+2NiYvTk1/LycpdeFLWKR11zlDl06JDL4zlW+TjKfF9wcLA+APgYhngAn9WmHhQ1n1aFk61bt8ru3btl0KBBLtcnTJgggYGBkpub67xPLUNWy4qTkpL0ubo9evSolJWVOcuoFUFqssyoUaM6/ooAdBs2lhkDPiugrcM6aoXOe++9p/dCccwZUfNCQkJC9O3cuXMlIyNDT5xVoWPx4sU6lKgVPIpalqyCyJw5c2TdunX6MVasWKEfm14SAN8f4tHDPAB8TpsCSnZ2tr695ZZbXO5/7bXX5KGHHtJfv/DCC+Ln56c3aFOrb9QKnVdffdVZ1t/fXw8PLViwQAeXXr16SXp6uqxZs6ZzXhGAbrWKhzkogG/q0D4onsI+KIBv7IMSkzhV+k/6mfgFtLw/EgDzuW0fFADoUmoOCj0ogE8ioAAwe4iHSbKATyKgADAWk2QB38U7H4DbtXbqm15mzBAP4JMIKAA8olUZhd4TwGfx7gfgfpb9/45WYCdZwDcRUAC4nWXZWz3MA8A3EVAAuJ1lb30PCgDfREAB4H66B4WAAuDqCCgA3E6HE4Z4ADSDgALAI0M89KAAaA4BBYD7McQDoAUEFABup1fwMMQDoBkEFABuxxAPgJYQUAC4nWU1iKilxgBwFQQUAO5nWWzUBqBZBBQAbsdGbQBaQkAB4HZsdQ+gJQQUAEZ/WCAA30RAAeB29KAAaAkBBYDbscwYQEsIKADcrvbSt1JX9d9mywSEhEmPsH5uqxMAsxBQALhd7ZUKqa++3GyZwB6hEhQa6bY6ATALAQWAmWx+YrPxKwrwVbz7ARjJZrPpkALAN/HuB2BuD4ofv6IAX8W7H4Cx
PSgM8QC+i3c/ADOpIR56UACfxbsfgJFU7wk9KIDv4t0PwExqiIceFMBn8e4HYCR6UADf1qZ3f1ZWlkyaNEl69+4tUVFRMmvWLCkqKnIpc8stt/y/yW3//3jkkUdcypw5c0ZmzJghPXv21I+zbNkyqa+v75xXBKB7YJkx4NMC2lJ47969snDhQh1SVKB46qmnZOrUqfLVV19Jr169nOXmzZsna9ascZ6rIOLQ0NCgw0lMTIwcOHBAiouL5cEHH5TAwEB57rnnOut1AegOPSh+/p6uBgBvCCi7du1yOd+4caPuAcnPz5ebb77ZJZCoANKUjz76SAeajz/+WKKjo2X8+PHy7LPPyvLly+WZZ56RoKCg9r4WAN0Jy4wBn9ahd39FRYW+jYx0/byMN954Q/r27StjxoyRzMxMuXLlivNaXl6ejB07VocTh9TUVKmsrJTCwsImn6empkZfb3wA6N50OGGSLOCz2tSD0pjdbpclS5bIDTfcoIOIwwMPPCADBw6UuLg4+fLLL3XPiJqn8u677+rrJSUlLuFEcZyra1eb+7J69er2VhWAN6IHBfBp7Q4oai7KsWPHZP/+/S73z58/3/m16imJjY2V5ORkOXXqlAwZMqRdz6V6YTIyMpznqgclPj6+vVUH4AVYxQP4tna9+xctWiQ7duyQTz75RAYMGNBs2cmTJ+vbkydP6ls1N6W0tNSljOP8avNWgoODJSwszOUA0M2xDwrg09r07rcsS4eTrVu3yu7du2XQoEEtfk9BQYG+VT0pSlJSkhw9elTKysqcZXJycnToGDVqVNtfAQCvon6PiKijebr3hB4UwGcFtHVYZ9OmTfLee+/pvVAcc0bCw8MlJCRED+Oo69OnT5c+ffroOShLly7VK3wSExN1WbUsWQWROXPmyLp16/RjrFixQj+26ikB4APsLQcURe2jBMA3temfJ9nZ2XrljtqMTfWIOI633npLX1dLhNXyYRVCRowYIY899pikpaXJ9u3bnY/h7++vh4fUrepN+eUvf6n3QWm8bwqAbsyyxLIaPF0LAN2pB+X/umavTk1cVZu5tUSt8vnggw/a8tQAuglL/c9u93Q1ABiOAV4A7kUPCoBWIKAAcH9AoQcFQAsIKADcyrLsYtnpQQHQPAIKADdTPSgEFADNI6AA8MAcFIZ4ADSPgALArfRqQOagAGgBAQWAmzHEA6BlBBQA7sUQD4BWIKAAcPsQDz0oAFpCQAHgZgQUAC0joABwLybJAmgFAgoAtw/x2NnqHkALCCgA3IweFAAtI6AAcP8kWXpQALSAgALArequVMil4hPNlvEPCpHwhLFuqxMA8xBQALiV1VAv9rqaZsvYbH4S0KOX2+oEwDwEFADmsdnE5hfg6VoA8CACCgAj2fz8PV0FAB5EQAFgINWDQkABfBl9qADapL6+vkPf39CaJcY2EbvVsefy8/PTBwDvREAB0Cbjx4+XoqKidn9/4uAo2ZAxo9kyJSWl8uDNN8vxMxfa/TybN2+WtLS0dn8/AM8ioABok4aGhg71bKjvb81eKdU1dR16HjubwQFejYACwGMu1MZKRX0/aZAA6eFXJf2CzkgPv+/UXrPS0EDAAHwZAQWAR/zPlUQ5Wz1SvrP3Ekv8JcBWI99UD5frwnJErCqppwcE8GnMIAPgVpbY5Mx3I+XElYlyxR4ulv53kk3qrR5SXh8jB8rvlnrLX+rpQQF8GgEFgFuV10VLYdWNYr9KB26Nvaf8/b9pUl9PQAF8GQEFgAfYmr1mqSXG9KAAPo2AAsA4BBQABBQABrKYJAv4OAIKALcKDyyT4T0/E5s0HUACbLUyJXwbPSiAj2tTQMnOzpbExEQJCwvTR1JSkuzcudN5vbq6WhYuXCh9+vSR0NBQvYtjaWmpy2OcOXNGZsyYIT179pSoqChZtmxZh7fOBuA9/MQug0K+lMEhBRLsVyU2URu3qYXGtdLL/79yU8QWCbRVE1AAH9emfVAGDBgga9eulWHD
humdHl9//XW566675MiRIzJ69GhZunSpvP/++7JlyxYJDw+XRYsWyd133y2ffvqpcwdJFU5iYmLkwIEDUlxcLA8++KAEBgbKc88911WvEYBBLlR+J+99elxEjktZbYJcrIuVBitQQvwrJS7olOzyr5L/XvpOz0MB4LtslkoaHRAZGSnPP/+83HPPPdKvXz/ZtGmT/lo5fvy4jBw5UvLy8mTKlCm6t+XOO++Uc+fOSXR0tC6zYcMGWb58uZw/f16CgoJa9ZyVlZU6AD300EOt/h4AnePtt9+W8vJyMV1KSooMHjzY09UA0Ehtba1s3LhRKioq9EhMl+wkq3pDVE9JVVWVHurJz8+Xuro6/UvBYcSIEZKQkOAMKOp27NixznCipKamyoIFC6SwsFCuu+66Jp+rpqZGH40DijJnzhw9lATAfT788EOvCCjJycly2223eboaABq5fPmyDiit0eaAcvToUR1I1HwTFQ62bt0qo0aNkoKCAt2bERER4VJehZGSkhL9tbptHE4c1x3XriYrK0tWr179g/snTpzYYgID0LlCQkLEGwwZMkSuv/56T1cDQCOODoYuWcUzfPhwHUYOHjyoez7S09Plq6++kq6UmZmpu4Mcx9mzZ7v0+QAAgGe1uQdF9ZIMHTpUfz1hwgQ5fPiwvPTSS3LvvffqsSXV9du4F0Wt4lGTYhV1e+jQIZfHc6zycZRpSnBwsD4AAIBv6PA+KHa7Xc8PUWFFrcbJzc11XisqKtLLitWQkKJu1RBRWVmZs0xOTo4eplHDRAAAAG3uQVFDLdOmTdMTXy9duqRX7OzZs0dPmlOraubOnSsZGRl6ZY8KHYsXL9ahRE2QVaZOnaqDiJrcum7dOj3vZMWKFXrvFHpIAABAuwKK6vlQ+5ao/UtUIFGbtqlwcvvtt+vrL7zwgvj5+ekN2lSvilqh8+qrrzq/39/fX3bs2KHnrqjg0qtXLz2HZc2aNW2pBgAA6OY6vA+KJzj2QWnNOmoAnUvtbaT2OPKG/Vp+8YtfeLoaANr595vP4gEAAMYhoAAAAOMQUAAAgHEIKAAAwDjt/iweAL5Jfd6W+pwt0/Xv39/TVQDQAQQUAG3yyiuveLoKAHwAQzwAAMA4BBQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHEIKAAAwDgEFAAAYBwCCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIB3B5Ts7GxJTEyUsLAwfSQlJcnOnTud12+55Rax2WwuxyOPPOLyGGfOnJEZM2ZIz549JSoqSpYtWyb19fWd94oAAIDXC2hL4QEDBsjatWtl2LBhYlmWvP7663LXXXfJkSNHZPTo0brMvHnzZM2aNc7vUUHEoaGhQYeTmJgYOXDggBQXF8uDDz4ogYGB8txzz3Xm6wIAAF7MZqmk0QGRkZHy/PPPy9y5c3UPyvjx4+XFF19ssqzqbbnzzjvl3LlzEh0dre/bsGGDLF++XM6fPy9BQUGtes7KykoJDw+XiooK3ZMDAADM15a/3+2eg6J6QzZv3ixVVVV6qMfhjTfekL59+8qYMWMkMzNTrly54ryWl5cnY8eOdYYTJTU1VVe4sLDwqs9VU1OjyzQ+AABA99WmIR7l6NGjOpBUV1dLaGiobN26VUaNGqWvPfDAAzJw4ECJi4uTL7/8UveMFBUVybvvvquvl5SUuIQTxXGurl1NVlaWrF69uq1VBQAAvhJQhg8fLgUFBbp75p133pH09HTZu3evDinz5893llM9JbGxsZKcnCynTp2SIUOGtLuSqicmIyPDea56UOLj49v9eAAAwGxtHuJR80SGDh0qEyZM0D0b48aNk5deeqnJspMnT9a3J0+e1LdqcmxpaalLGce5unY1wcHBzpVDjgMAAHRfHd4HxW636zkiTVE9
LYrqSVHU0JAaIiorK3OWycnJ0YHDMUwEAAAQ0NahlmnTpklCQoJcunRJNm3aJHv27JEPP/xQD+Oo8+nTp0ufPn30HJSlS5fKzTffrPdOUaZOnaqDyJw5c2TdunV63smKFStk4cKFupcEAACgzQFF9XyofUvU/iVqmZAKHiqc3H777XL27Fn5+OOP9RJjtbJHzRFJS0vTAcTB399fduzYIQsWLNC9Kb169dJzWBrvmwIAANDhfVA8gX1QAADwPm7ZBwUAAKCrEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHEIKAAAwDgEFAAAYBwCCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMEiBeyLEvfVlZWeroqAACglRx/tx1/x7tdQLl06ZK+jY+P93RVAABAO/6Oh4eHN1vGZrUmxhjGbrdLUVGRjBo1Ss6ePSthYWGerpJXp1kV9GjHjqMtOw9t2Tlox85DW3YOFTlUOImLixM/P7/u14OiXlT//v311+oHhR+WjqMdOw9t2Xloy85BO3Ye2rLjWuo5cWCSLAAAMA4BBQAAGMdrA0pwcLCsWrVK36L9aMfOQ1t2Htqyc9COnYe2dD+vnCQLAAC6N6/tQQEAAN0XAQUAABiHgAIAAIxDQAEAAMbxyoCyfv16ueaaa6RHjx4yefJkOXTokKerZJx9+/bJzJkz9W59NptNtm3b5nJdzY1euXKlxMbGSkhIiKSkpMiJEydcyly8eFFmz56tNyWKiIiQuXPnyuXLl8WXZGVlyaRJk6R3794SFRUls2bN0rsYN1ZdXS0LFy6UPn36SGhoqKSlpUlpaalLmTNnzsiMGTOkZ8+e+nGWLVsm9fX14iuys7MlMTHRuclVUlKS7Ny503mdNmy/tWvX6vf4kiVLnPfRnq3zzDPP6LZrfIwYMcJ5nXb0MMvLbN682QoKCrL+8pe/WIWFhda8efOsiIgIq7S01NNVM8oHH3xg/eY3v7HeffddtUrL2rp1q8v1tWvXWuHh4da2bdusf/zjH9bPfvYza9CgQdZ3333nLHPHHXdY48aNsz777DPr73//uzV06FDr/vvvt3xJamqq9dprr1nHjh2zCgoKrOnTp1sJCQnW5cuXnWUeeeQRKz4+3srNzbU+//xza8qUKdZPfvIT5/X6+nprzJgxVkpKinXkyBH936Zv375WZmam5Sv+9re/We+//771z3/+0yoqKrKeeuopKzAwULerQhu2z6FDh6xrrrnGSkxMtB599FHn/bRn66xatcoaPXq0VVxc7DzOnz/vvE47epbXBZTrr7/eWrhwofO8oaHBiouLs7KysjxaL5N9P6DY7XYrJibGev755533lZeXW8HBwdabb76pz7/66iv9fYcPH3aW2blzp2Wz2az//Oc/lq8qKyvT7bJ3715nu6k/tFu2bHGW+frrr3WZvLw8fa5+afn5+VklJSXOMtnZ2VZYWJhVU1Nj+aof/ehH1p/+9CfasJ0uXbpkDRs2zMrJybF++tOfOgMK7dm2gKL+EdYU2tHzvGqIp7a2VvLz8/VwROPP5VHneXl5Hq2bNzl9+rSUlJS4tKP6bAQ1XOZoR3WrhnUmTpzoLKPKq/Y+ePCg+KqKigp9GxkZqW/Vz2NdXZ1LW6ou4oSEBJe2HDt2rERHRzvLpKam6g8fKywsFF/T0NAgmzdvlqqqKj3UQxu2jxp6UEMLjdtNoT3bRg1tq6HwwYMH6yFtNWSj0I6e51UfFvjtt9/qX26NfxgUdX78+HGP1cvbqHCiNNWOjmvqVo2nNhYQEKD/MDvK+Br1KdpqnP+GG26QMWPG6PtUWwQFBekw11xbNtXW
jmu+4ujRozqQqHF9NZ6/detW/YnkBQUFtGEbqYD3xRdfyOHDh39wjZ/J1lP/KNu4caMMHz5ciouLZfXq1XLTTTfJsWPHaEcDeFVAATz9L1b1i2v//v2eropXUn8EVBhRvVDvvPOOpKeny969ez1dLa9z9uxZefTRRyUnJ0cvFED7TZs2zfm1msStAsvAgQPl7bff1osH4FleNcTTt29f8ff3/8EsanUeExPjsXp5G0dbNdeO6rasrMzlupqZrlb2+GJbL1q0SHbs2CGffPKJDBgwwHm/ags19FheXt5sWzbV1o5rvkL9a3To0KEyYcIEvTpq3Lhx8tJLL9GGbaSGHtR788c//rHu1VSHCnovv/yy/lr9C572bB/VW3LttdfKyZMn+bk0gJ+3/YJTv9xyc3Ndut3Vueo6RusMGjRIv3kat6MaM1VzSxztqG7VG1P9MnTYvXu3bm/1rwxfoeYYq3CihiPU61dt15j6eQwMDHRpS7UMWY1jN25LNbzROPCpf/2q5bZqiMNXqZ+lmpoa2rCNkpOTdVuo3ijHoeaKqfkTjq9pz/ZR2yicOnVKb7/Az6UBLC9cZqxWm2zcuFGvNJk/f75eZtx4FjX+b4a/WvamDvWf+Q9/+IP++t///rdzmbFqt/fee8/68ssvrbvuuqvJZcbXXXeddfDgQWv//v16xYCvLTNesGCBXo69Z88el6WIV65ccVmKqJYe7969Wy9FTEpK0sf3lyJOnTpVL1XetWuX1a9fP59aivjkk0/qlU+nT5/WP2/qXK0I++ijj/R12rBjGq/iUWjP1nnsscf0e1v9XH766ad6ubBaJqxW6ym0o2d5XUBRXnnlFf1Do/ZDUcuO1T4dcPXJJ5/oYPL9Iz093bnU+Omnn7aio6N14EtOTtb7UzR24cIFHUhCQ0P1srmHH35YBx9f0lQbqkPtjeKgQt2vf/1rvWy2Z8+e1s9//nMdYhr717/+ZU2bNs0KCQnRvwDVL8a6ujrLV/zqV7+yBg4cqN+z6he4+nlzhBOFNuzcgEJ7ts69995rxcbG6p/L/v376/OTJ086r9OOnmVT/+fpXhwAAACvnYMCAAB8AwEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAGKa/wUsLB8CpxW1GQAAAABJRU5ErkJggg=="
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": 30
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:51:56.859049Z",
     "start_time": "2025-09-02T02:51:56.820493Z"
    }
   },
   "cell_type": "code",
   "source": [
    "#Smoke-test the wrapped environment: reset, sample one action, take one step.\n",
    "def test_env():\n",
    "    state = env.reset()\n",
    "    print('这个游戏的状态用4个数字表示,我也不知道这4个数字分别是什么意思,反正这4个数字就能描述游戏全部的状态')\n",
    "    print('state=', state)\n",
    "    #example: state= [ 0.03490619  0.04873464  0.04908862 -0.00375859]\n",
    "\n",
    "    print('这个游戏一共有2个动作,不是0就是1')\n",
    "    print('env.action_space=', env.action_space)\n",
    "    #example: env.action_space= Discrete(2)\n",
    "\n",
    "    print('随机一个动作')\n",
    "    action = env.action_space.sample()\n",
    "    print('action=', action)\n",
    "    #example: action= 1\n",
    "\n",
    "    print('执行一个动作,得到下一个状态,奖励,是否结束')\n",
    "    state, reward, over, _ = env.step(action)\n",
    "\n",
    "    print('state=', state)\n",
    "    #example: state= [ 0.02018229 -0.16441101  0.01547085  0.2661691 ]\n",
    "\n",
    "    print('reward=', reward)\n",
    "    #example: reward= 1.0\n",
    "\n",
    "    print('over=', over)\n",
    "    #example: over= False\n",
    "\n",
    "\n",
    "test_env()"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "这个游戏的状态用4个数字表示,我也不知道这4个数字分别是什么意思,反正这4个数字就能描述游戏全部的状态\n",
      "state= [ 0.04897398  0.02883779 -0.01016652  0.02206425]\n",
      "这个游戏一共有2个动作,不是0就是1\n",
      "env.action_space= Discrete(2)\n",
      "随机一个动作\n",
      "action= 1\n",
      "执行一个动作,得到下一个状态,奖励,是否结束\n"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "module 'numpy' has no attribute 'bool8'",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mAttributeError\u001B[0m                            Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[31], line 30\u001B[0m\n\u001B[1;32m     26\u001B[0m     \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mover=\u001B[39m\u001B[38;5;124m'\u001B[39m, over)\n\u001B[1;32m     27\u001B[0m     \u001B[38;5;66;03m#over= False\u001B[39;00m\n\u001B[0;32m---> 30\u001B[0m \u001B[43mtest_env\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n",
      "Cell \u001B[0;32mIn[31], line 18\u001B[0m, in \u001B[0;36mtest_env\u001B[0;34m()\u001B[0m\n\u001B[1;32m     15\u001B[0m \u001B[38;5;66;03m#action= 1\u001B[39;00m\n\u001B[1;32m     17\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m执行一个动作,得到下一个状态,奖励,是否结束\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[0;32m---> 18\u001B[0m state, reward, over, _ \u001B[38;5;241m=\u001B[39m \u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mstep\u001B[49m\u001B[43m(\u001B[49m\u001B[43maction\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     20\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mstate=\u001B[39m\u001B[38;5;124m'\u001B[39m, state)\n\u001B[1;32m     21\u001B[0m \u001B[38;5;66;03m#state= [ 0.02018229 -0.16441101  0.01547085  0.2661691 ]\u001B[39;00m\n",
      "Cell \u001B[0;32mIn[29], line 19\u001B[0m, in \u001B[0;36mMyWrapper.step\u001B[0;34m(self, action)\u001B[0m\n\u001B[1;32m     18\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21mstep\u001B[39m(\u001B[38;5;28mself\u001B[39m, action):\n\u001B[0;32m---> 19\u001B[0m     state, reward, terminated, truncated, info \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mstep\u001B[49m\u001B[43m(\u001B[49m\u001B[43maction\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     20\u001B[0m     done \u001B[38;5;241m=\u001B[39m terminated \u001B[38;5;129;01mor\u001B[39;00m truncated\n\u001B[1;32m     21\u001B[0m     \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mstep_n \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1\u001B[39m\n",
      "File \u001B[0;32m~/PycharmProjects/Simple_Reinforcement_Learning/.venv/lib/python3.10/site-packages/gym/wrappers/time_limit.py:50\u001B[0m, in \u001B[0;36mTimeLimit.step\u001B[0;34m(self, action)\u001B[0m\n\u001B[1;32m     39\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21mstep\u001B[39m(\u001B[38;5;28mself\u001B[39m, action):\n\u001B[1;32m     40\u001B[0m \u001B[38;5;250m    \u001B[39m\u001B[38;5;124;03m\"\"\"Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate.\u001B[39;00m\n\u001B[1;32m     41\u001B[0m \n\u001B[1;32m     42\u001B[0m \u001B[38;5;124;03m    Args:\u001B[39;00m\n\u001B[0;32m   (...)\u001B[0m\n\u001B[1;32m     48\u001B[0m \n\u001B[1;32m     49\u001B[0m \u001B[38;5;124;03m    \"\"\"\u001B[39;00m\n\u001B[0;32m---> 50\u001B[0m     observation, reward, terminated, truncated, info \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mstep\u001B[49m\u001B[43m(\u001B[49m\u001B[43maction\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     51\u001B[0m     \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_elapsed_steps \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1\u001B[39m\n\u001B[1;32m     53\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_elapsed_steps \u001B[38;5;241m>\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_max_episode_steps:\n",
      "File \u001B[0;32m~/PycharmProjects/Simple_Reinforcement_Learning/.venv/lib/python3.10/site-packages/gym/wrappers/order_enforcing.py:37\u001B[0m, in \u001B[0;36mOrderEnforcing.step\u001B[0;34m(self, action)\u001B[0m\n\u001B[1;32m     35\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_has_reset:\n\u001B[1;32m     36\u001B[0m     \u001B[38;5;28;01mraise\u001B[39;00m ResetNeeded(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mCannot call env.step() before calling env.reset()\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m---> 37\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mstep\u001B[49m\u001B[43m(\u001B[49m\u001B[43maction\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[0;32m~/PycharmProjects/Simple_Reinforcement_Learning/.venv/lib/python3.10/site-packages/gym/wrappers/env_checker.py:37\u001B[0m, in \u001B[0;36mPassiveEnvChecker.step\u001B[0;34m(self, action)\u001B[0m\n\u001B[1;32m     35\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mchecked_step \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mFalse\u001B[39;00m:\n\u001B[1;32m     36\u001B[0m     \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mchecked_step \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mTrue\u001B[39;00m\n\u001B[0;32m---> 37\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43menv_step_passive_checker\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maction\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     38\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m     39\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39menv\u001B[38;5;241m.\u001B[39mstep(action)\n",
      "File \u001B[0;32m~/PycharmProjects/Simple_Reinforcement_Learning/.venv/lib/python3.10/site-packages/gym/utils/passive_env_checker.py:233\u001B[0m, in \u001B[0;36menv_step_passive_checker\u001B[0;34m(env, action)\u001B[0m\n\u001B[1;32m    230\u001B[0m obs, reward, terminated, truncated, info \u001B[38;5;241m=\u001B[39m result\n\u001B[1;32m    232\u001B[0m \u001B[38;5;66;03m# np.bool is actual python bool not np boolean type, therefore bool_ or bool8\u001B[39;00m\n\u001B[0;32m--> 233\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(terminated, (\u001B[38;5;28mbool\u001B[39m, \u001B[43mnp\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbool8\u001B[49m)):\n\u001B[1;32m    234\u001B[0m     logger\u001B[38;5;241m.\u001B[39mwarn(\n\u001B[1;32m    235\u001B[0m         \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mExpects `terminated` signal to be a boolean, actual type: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[38;5;28mtype\u001B[39m(terminated)\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m    236\u001B[0m     )\n\u001B[1;32m    237\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(truncated, (\u001B[38;5;28mbool\u001B[39m, np\u001B[38;5;241m.\u001B[39mbool8)):\n",
      "File \u001B[0;32m~/PycharmProjects/Simple_Reinforcement_Learning/.venv/lib/python3.10/site-packages/numpy/__init__.py:414\u001B[0m, in \u001B[0;36m__getattr__\u001B[0;34m(attr)\u001B[0m\n\u001B[1;32m    411\u001B[0m     \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mnumpy\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mchar\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mas\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mchar\u001B[39;00m\n\u001B[1;32m    412\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m char\u001B[38;5;241m.\u001B[39mchararray\n\u001B[0;32m--> 414\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mAttributeError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mmodule \u001B[39m\u001B[38;5;132;01m{!r}\u001B[39;00m\u001B[38;5;124m has no attribute \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m    415\u001B[0m                      \u001B[38;5;124m\"\u001B[39m\u001B[38;5;132;01m{!r}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;241m.\u001B[39mformat(\u001B[38;5;18m__name__\u001B[39m, attr))\n",
      "\u001B[0;31mAttributeError\u001B[0m: module 'numpy' has no attribute 'bool8'"
     ]
    }
   ],
   "execution_count": 31
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:51:58.276117Z",
     "start_time": "2025-09-02T02:51:58.270752Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "#Build one Q-network: 4 state features -> 128 hidden units -> 2 action values.\n",
    "#A factory avoids duplicating the identical architecture definition twice.\n",
    "def build_q_net():\n",
    "    return torch.nn.Sequential(\n",
    "        torch.nn.Linear(4, 128),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.Linear(128, 2),\n",
    "    )\n",
    "\n",
    "\n",
    "#Policy network: the model actually used to pick actions.\n",
    "model = build_q_net()\n",
    "\n",
    "#Target network: used to evaluate next-state values during training.\n",
    "next_model = build_q_net()\n",
    "\n",
    "#Start the target network with the same weights as the policy network.\n",
    "next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "model, next_model"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Sequential(\n",
       "   (0): Linear(in_features=4, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=2, bias=True)\n",
       " ),\n",
       " Sequential(\n",
       "   (0): Linear(in_features=4, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=2, bias=True)\n",
       " ))"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 32
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:52:07.929909Z",
     "start_time": "2025-09-02T02:52:07.924898Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import random\n",
    "\n",
    "\n",
    "#Pick an action for the given state (epsilon-greedy, epsilon = 0.01).\n",
    "def get_action(state):\n",
    "    #with small probability, explore with a uniformly random action\n",
    "    if random.random() < 0.01:\n",
    "        return random.choice([0, 1])\n",
    "\n",
    "    #otherwise exploit: run the policy network and take the argmax action\n",
    "    state_tensor = torch.FloatTensor(state).reshape(1, 4)\n",
    "    q_values = model(state_tensor)\n",
    "    return q_values.argmax().item()\n",
    "\n",
    "\n",
    "get_action([0.0013847, -0.01194451, 0.04260966, 0.00688801])"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 33
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:52:08.657842Z",
     "start_time": "2025-09-02T02:52:08.645632Z"
    }
   },
   "cell_type": "code",
   "source": [
    "#Replay buffer of (state, action, reward, next_state, over) transitions\n",
    "datas = []\n",
    "\n",
    "\n",
    "#Play episodes until at least 200 new transitions are collected, then trim\n",
    "#the buffer to 10000 entries by dropping the oldest ones.\n",
    "#Returns (number added, number dropped).\n",
    "def update_data():\n",
    "    old_count = len(datas)\n",
    "\n",
    "    #play until at least 200 new transitions have been added\n",
    "    while len(datas) - old_count < 200:\n",
    "        #start a new episode\n",
    "        state = env.reset()\n",
    "\n",
    "        #play until the episode ends\n",
    "        over = False\n",
    "        while not over:\n",
    "            #pick an action for the current state\n",
    "            action = get_action(state)\n",
    "\n",
    "            #execute the action and observe the outcome\n",
    "            next_state, reward, over, _ = env.step(action)\n",
    "\n",
    "            #store the transition\n",
    "            datas.append((state, action, reward, next_state, over))\n",
    "\n",
    "            #advance to the next state\n",
    "            state = next_state\n",
    "\n",
    "    update_count = len(datas) - old_count\n",
    "    drop_count = max(len(datas) - 10000, 0)\n",
    "\n",
    "    #cap the buffer: drop the oldest entries with one slice deletion,\n",
    "    #O(n) overall instead of repeated pop(0) which is O(n^2)\n",
    "    del datas[:drop_count]\n",
    "\n",
    "    return update_count, drop_count\n",
    "\n",
    "\n",
    "update_data(), len(datas)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((200, 0), 200)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 34
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:52:09.332098Z",
     "start_time": "2025-09-02T02:52:09.321065Z"
    }
   },
   "cell_type": "code",
   "source": [
    "#Draw a random minibatch of 64 transitions from the replay buffer.\n",
    "def get_sample():\n",
    "    #sample without replacement from the buffer\n",
    "    batch = random.sample(datas, 64)\n",
    "\n",
    "    #unpack the transition tuples into per-field sequences\n",
    "    states, actions, rewards, next_states, overs = zip(*batch)\n",
    "\n",
    "    #[b, 4]\n",
    "    state = torch.FloatTensor(states).reshape(-1, 4)\n",
    "    #[b, 1]\n",
    "    action = torch.LongTensor(actions).reshape(-1, 1)\n",
    "    #[b, 1]\n",
    "    reward = torch.FloatTensor(rewards).reshape(-1, 1)\n",
    "    #[b, 4]\n",
    "    next_state = torch.FloatTensor(next_states).reshape(-1, 4)\n",
    "    #[b, 1]\n",
    "    over = torch.LongTensor(overs).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "state, action, reward, next_state, over"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 9.7867e-02,  1.4053e+00, -1.5834e-01, -2.1787e+00],\n",
       "         [ 1.5729e-01,  1.5895e+00, -1.9095e-01, -2.5295e+00],\n",
       "         [ 7.7323e-02,  1.4003e+00, -1.2129e-01, -2.1432e+00],\n",
       "         [-3.8111e-02,  3.6296e-01, -1.0551e-02, -5.6489e-01],\n",
       "         [ 1.5088e-01,  1.5837e+00, -1.9772e-01, -2.5024e+00],\n",
       "         [-1.8079e-02,  2.3543e-02, -3.9243e-05, -5.4255e-03],\n",
       "         [ 1.4370e-02,  1.1451e+00, -9.1574e-02, -1.7765e+00],\n",
       "         [-1.3235e-02,  4.1379e-01, -6.1102e-03, -5.9085e-01],\n",
       "         [ 4.4009e-02,  2.3759e-02, -1.4068e-02, -4.8160e-02],\n",
       "         [-4.2861e-03,  4.2702e-01, -8.1905e-03, -5.9502e-01],\n",
       "         [ 1.8957e-02,  9.5947e-01, -9.2494e-02, -1.5079e+00],\n",
       "         [ 5.1338e-02,  6.0362e-01, -4.4947e-02, -9.0275e-01],\n",
       "         [ 7.2209e-03,  8.0436e-01, -3.5636e-02, -1.1837e+00],\n",
       "         [ 3.9188e-02,  1.1853e+00, -6.4408e-02, -1.7336e+00],\n",
       "         [ 3.8584e-02,  1.7062e-02, -2.7019e-02,  2.2794e-03],\n",
       "         [-2.1324e-02,  4.2930e-01, -2.0892e-02, -6.1607e-01],\n",
       "         [ 3.3386e-02,  1.3768e+00, -1.6663e-01, -2.1704e+00],\n",
       "         [-2.5679e-02,  7.8835e-01, -7.5387e-02, -1.2031e+00],\n",
       "         [ 8.1943e-02,  1.3347e+00, -1.6918e-01, -2.2163e+00],\n",
       "         [ 3.8273e-02,  1.1922e+00, -8.4012e-02, -1.8167e+00],\n",
       "         [ 1.0936e-01,  1.3504e+00, -1.4430e-01, -2.1744e+00],\n",
       "         [ 6.0721e-02,  1.4081e+00, -1.4325e-01, -2.1657e+00],\n",
       "         [ 4.4723e-02, -1.2875e-02,  9.5922e-04,  1.5271e-02],\n",
       "         [ 8.6274e-02,  1.1543e+00, -1.0729e-01, -1.8504e+00],\n",
       "         [ 3.5185e-02,  1.1777e+00, -6.1807e-02, -1.7206e+00],\n",
       "         [-3.3895e-02,  3.7518e-02, -2.4050e-03, -1.1485e-02],\n",
       "         [ 6.9350e-02,  8.0541e-01, -5.3632e-02, -1.2456e+00],\n",
       "         [ 6.7107e-02,  7.6785e-01, -3.2934e-02, -1.1610e+00],\n",
       "         [-1.9852e-04,  7.8696e-01, -1.1024e-02, -1.1215e+00],\n",
       "         [-2.2130e-02,  1.7362e-01, -3.2418e-02, -2.7856e-01],\n",
       "         [ 8.1494e-02,  1.6021e+00, -1.7115e-01, -2.4584e+00],\n",
       "         [ 8.9881e-02,  1.5843e+00, -1.6303e-01, -2.4616e+00],\n",
       "         [-1.7608e-02,  2.1867e-01, -1.4775e-04, -2.9812e-01],\n",
       "         [-1.9977e-02,  3.9703e-01,  1.6450e-02, -5.4314e-01],\n",
       "         [ 7.7141e-02,  1.5583e+00, -2.0075e-01, -2.4923e+00],\n",
       "         [ 1.4393e-02,  5.5046e-01, -5.6447e-02, -9.3856e-01],\n",
       "         [-8.5652e-03,  2.2772e-01,  7.3989e-03, -3.3512e-01],\n",
       "         [-1.8260e-02, -1.9071e-02, -3.4175e-02,  2.4841e-02],\n",
       "         [ 4.1303e-02,  1.1260e+00, -4.3356e-02, -1.7565e+00],\n",
       "         [-2.3642e-02, -1.5674e-01, -1.9621e-02,  2.7705e-01],\n",
       "         [ 6.1258e-02,  1.3518e+00, -1.5921e-01, -2.1561e+00],\n",
       "         [ 5.8740e-02,  1.3735e+00, -9.6219e-02, -2.0318e+00],\n",
       "         [-4.0109e-03,  4.2273e-01,  6.9641e-04, -6.2546e-01],\n",
       "         [-2.6777e-02,  3.8653e-02, -1.4080e-02, -2.1756e-02],\n",
       "         [ 4.3177e-02,  4.0806e-01, -3.2950e-02, -5.9987e-01],\n",
       "         [ 6.7221e-02,  1.3917e+00, -1.2502e-01, -2.1169e+00],\n",
       "         [ 4.1155e-03,  1.5912e-01, -3.7234e-02, -3.2822e-01],\n",
       "         [ 1.2206e-01,  1.7732e+00, -1.8734e-01, -2.6995e+00],\n",
       "         [ 6.2095e-02,  8.1584e-01, -7.4715e-02, -1.2052e+00],\n",
       "         [-4.9590e-03,  6.0900e-01, -1.7927e-02, -8.8545e-01],\n",
       "         [ 7.3685e-02,  1.2091e+00, -1.2132e-01, -1.8509e+00],\n",
       "         [ 3.8925e-02,  2.1256e-01, -2.6973e-02, -2.9880e-01],\n",
       "         [-2.6003e-02,  2.3397e-01, -1.4515e-02, -3.1885e-01],\n",
       "         [ 5.9174e-02,  1.1385e+00, -1.3146e-01, -1.8859e+00],\n",
       "         [ 4.8456e-03, -3.6505e-02, -3.6750e-02, -2.4177e-02],\n",
       "         [ 5.7154e-02,  6.0983e-01, -3.4789e-02, -9.4214e-01],\n",
       "         [ 2.2689e-02,  9.3071e-01, -1.4168e-02, -1.4594e+00],\n",
       "         [-1.8657e-02,  3.6919e-01, -3.7989e-02, -5.8129e-01],\n",
       "         [ 8.4394e-02,  1.5451e+00, -2.0094e-01, -2.4910e+00],\n",
       "         [-3.7528e-02,  5.9250e-01, -5.7529e-02, -8.9291e-01],\n",
       "         [-3.3145e-02,  2.3267e-01, -2.6347e-03, -3.0493e-01],\n",
       "         [ 4.3307e-02,  1.1957e+00, -8.9058e-02, -1.7980e+00],\n",
       "         [-1.9935e-02,  6.2308e-01, -2.0702e-02, -8.9386e-01],\n",
       "         [-7.4730e-03,  8.1847e-01, -3.8579e-02, -1.1930e+00]]),\n",
       " tensor([[1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1]]),\n",
       " tensor([[1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.]]),\n",
       " tensor([[ 1.2597e-01,  1.6016e+00, -2.0191e-01, -2.5158e+00],\n",
       "         [ 1.8908e-01,  1.7856e+00, -2.4154e-01, -2.8741e+00],\n",
       "         [ 1.0533e-01,  1.5964e+00, -1.6416e-01, -2.4708e+00],\n",
       "         [-3.0852e-02,  5.5822e-01, -2.1849e-02, -8.6088e-01],\n",
       "         [ 1.8255e-01,  1.7799e+00, -2.4777e-01, -2.8486e+00],\n",
       "         [-1.7608e-02,  2.1867e-01, -1.4775e-04, -2.9812e-01],\n",
       "         [ 3.7272e-02,  1.3411e+00, -1.2710e-01, -2.0962e+00],\n",
       "         [-4.9590e-03,  6.0900e-01, -1.7927e-02, -8.8545e-01],\n",
       "         [ 4.4484e-02,  2.1908e-01, -1.5032e-02, -3.4525e-01],\n",
       "         [ 4.2542e-03,  6.2225e-01, -2.0091e-02, -8.9028e-01],\n",
       "         [ 3.8147e-02,  1.1556e+00, -1.2265e-01, -1.8280e+00],\n",
       "         [ 6.3410e-02,  7.9933e-01, -6.3002e-02, -1.2092e+00],\n",
       "         [ 2.3308e-02,  9.9992e-01, -5.9311e-02, -1.4874e+00],\n",
       "         [ 6.2894e-02,  1.3811e+00, -9.9080e-02, -2.0456e+00],\n",
       "         [ 3.8925e-02,  2.1256e-01, -2.6973e-02, -2.9880e-01],\n",
       "         [-1.2738e-02,  6.2471e-01, -3.3213e-02, -9.1526e-01],\n",
       "         [ 6.0923e-02,  1.5731e+00, -2.1004e-01, -2.5096e+00],\n",
       "         [-9.9115e-03,  9.8436e-01, -9.9449e-02, -1.5184e+00],\n",
       "         [ 1.0864e-01,  1.5310e+00, -2.1350e-01, -2.5560e+00],\n",
       "         [ 6.2118e-02,  1.3882e+00, -1.2035e-01, -2.1343e+00],\n",
       "         [ 1.3637e-01,  1.5466e+00, -1.8779e-01, -2.5079e+00],\n",
       "         [ 8.8882e-02,  1.6043e+00, -1.8657e-01, -2.4989e+00],\n",
       "         [ 4.4465e-02,  1.8223e-01,  1.2646e-03, -2.7711e-01],\n",
       "         [ 1.0936e-01,  1.3504e+00, -1.4430e-01, -2.1744e+00],\n",
       "         [ 5.8740e-02,  1.3735e+00, -9.6219e-02, -2.0318e+00],\n",
       "         [-3.3145e-02,  2.3267e-01, -2.6347e-03, -3.0493e-01],\n",
       "         [ 8.5458e-02,  1.0012e+00, -7.8543e-02, -1.5545e+00],\n",
       "         [ 8.2464e-02,  9.6339e-01, -5.6154e-02, -1.4638e+00],\n",
       "         [ 1.5541e-02,  9.8222e-01, -3.3455e-02, -1.4176e+00],\n",
       "         [-1.8657e-02,  3.6919e-01, -3.7989e-02, -5.8129e-01],\n",
       "         [ 1.1354e-01,  1.7982e+00, -2.2031e-01, -2.7983e+00],\n",
       "         [ 1.2157e-01,  1.7804e+00, -2.1227e-01, -2.7996e+00],\n",
       "         [-1.3235e-02,  4.1379e-01, -6.1102e-03, -5.9085e-01],\n",
       "         [-1.2037e-02,  5.9191e-01,  5.5876e-03, -8.3060e-01],\n",
       "         [ 1.0831e-01,  1.7544e+00, -2.5060e-01, -2.8392e+00],\n",
       "         [ 2.5402e-02,  7.4630e-01, -7.5218e-02, -1.2484e+00],\n",
       "         [-4.0109e-03,  4.2273e-01,  6.9641e-04, -6.2546e-01],\n",
       "         [-1.8641e-02,  1.7652e-01, -3.3678e-02, -2.7843e-01],\n",
       "         [ 6.3823e-02,  1.3216e+00, -7.8486e-02, -2.0623e+00],\n",
       "         [-2.6777e-02,  3.8653e-02, -1.4080e-02, -2.1756e-02],\n",
       "         [ 8.8295e-02,  1.5481e+00, -2.0234e-01, -2.4934e+00],\n",
       "         [ 8.6210e-02,  1.5695e+00, -1.3686e-01, -2.3527e+00],\n",
       "         [ 4.4438e-03,  6.1785e-01, -1.1813e-02, -9.1793e-01],\n",
       "         [-2.6003e-02,  2.3397e-01, -1.4515e-02, -3.1885e-01],\n",
       "         [ 5.1338e-02,  6.0362e-01, -4.4947e-02, -9.0275e-01],\n",
       "         [ 9.5055e-02,  1.5878e+00, -1.6736e-01, -2.4455e+00],\n",
       "         [ 7.2980e-03,  3.5476e-01, -4.3798e-02, -6.3241e-01],\n",
       "         [ 1.5752e-01,  1.9691e+00, -2.4133e-01, -3.0430e+00],\n",
       "         [ 7.8411e-02,  1.0118e+00, -9.8818e-02, -1.5203e+00],\n",
       "         [ 7.2209e-03,  8.0436e-01, -3.5636e-02, -1.1837e+00],\n",
       "         [ 9.7867e-02,  1.4053e+00, -1.5834e-01, -2.1787e+00],\n",
       "         [ 4.3177e-02,  4.0806e-01, -3.2950e-02, -5.9987e-01],\n",
       "         [-2.1324e-02,  4.2930e-01, -2.0892e-02, -6.1607e-01],\n",
       "         [ 8.1943e-02,  1.3347e+00, -1.6918e-01, -2.2163e+00],\n",
       "         [ 4.1155e-03,  1.5912e-01, -3.7234e-02, -3.2822e-01],\n",
       "         [ 6.9350e-02,  8.0541e-01, -5.3632e-02, -1.2456e+00],\n",
       "         [ 4.1303e-02,  1.1260e+00, -4.3356e-02, -1.7565e+00],\n",
       "         [-1.1273e-02,  5.6482e-01, -4.9615e-02, -8.8569e-01],\n",
       "         [ 1.1530e-01,  1.7413e+00, -2.5076e-01, -2.8380e+00],\n",
       "         [-2.5679e-02,  7.8835e-01, -7.5387e-02, -1.2031e+00],\n",
       "         [-2.8491e-02,  4.2783e-01, -8.7332e-03, -5.9844e-01],\n",
       "         [ 6.7221e-02,  1.3917e+00, -1.2502e-01, -2.1169e+00],\n",
       "         [-7.4730e-03,  8.1847e-01, -3.8579e-02, -1.1930e+00],\n",
       "         [ 8.8964e-03,  1.0141e+00, -6.2439e-02, -1.4975e+00]]),\n",
       " tensor([[0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0]]))"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 35
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:52:10.298661Z",
     "start_time": "2025-09-02T02:52:10.294272Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def get_value(state, action):\n",
    "    \"\"\"Score the actions actually taken, under the current policy network.\n",
    "\n",
    "    state  -- batch of observations, shape [b, 4]\n",
    "    action -- batch of executed action indices, shape [b, 1]\n",
    "    Returns a [b, 1] tensor: the model's Q-value for each (state, action).\n",
    "    \"\"\"\n",
    "    # Forward pass scores every action: [b, 4] -> [b, 2].\n",
    "    logits = model(state)\n",
    "\n",
    "    # Keep only the score of the action that was actually executed.\n",
    "    # Reward and next_state are unknown at decision time, so they\n",
    "    # play no role here.  [b, 2] -> [b, 1]\n",
    "    return logits.gather(dim=1, index=action)\n",
    "\n",
    "\n",
    "get_value(state, action).shape"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 1])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 36
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:52:10.971163Z",
     "start_time": "2025-09-02T02:52:10.965647Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def get_target(reward, next_state, over):\n",
    "    \"\"\"Compute the TD target for a batch of transitions.\n",
    "\n",
    "    The target is reward + 0.98 * max_a Q_next(next_state, a), with the\n",
    "    bootstrap term zeroed for transitions where the episode has ended\n",
    "    (no future to account for).  The slowly-synced next_model provides\n",
    "    the estimate, which stabilises training versus using the live model.\n",
    "    Returns a [b, 1] tensor.\n",
    "    \"\"\"\n",
    "    # Evaluate all actions in the next state without tracking gradients.\n",
    "    # [b, 4] -> [b, 2]\n",
    "    with torch.no_grad():\n",
    "        next_value = next_model(next_state)\n",
    "\n",
    "    # Best achievable score in the next state: [b, 2] -> [b, 1].\n",
    "    next_value = next_value.max(dim=1)[0].reshape(-1, 1)\n",
    "\n",
    "    # Discount the bootstrap term, zero it out for terminal transitions\n",
    "    # (over == 1), then add the immediate reward.\n",
    "    # [b, 1] * [b, 1] + [b, 1] -> [b, 1]\n",
    "    return reward + next_value * 0.98 * (1 - over)\n",
    "\n",
    "\n",
    "get_target(reward, next_state, over).shape"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 1])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 37
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:52:11.693844Z",
     "start_time": "2025-09-02T02:52:11.687409Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    \"\"\"Play one full episode with the current policy.\n",
    "\n",
    "    play -- when True, render frames inline (animated via clear_output).\n",
    "    Returns the total reward collected over the episode (higher is better).\n",
    "    \"\"\"\n",
    "    # Fresh episode.\n",
    "    state = env.reset()\n",
    "\n",
    "    total_reward = 0\n",
    "\n",
    "    # Step until the environment reports the episode is over.\n",
    "    done = False\n",
    "    while not done:\n",
    "        # Choose an action for the current state.\n",
    "        action = get_action(state)\n",
    "\n",
    "        # Apply it and accumulate the reward.\n",
    "        state, reward, done, _ = env.step(action)\n",
    "        total_reward += reward\n",
    "\n",
    "        # Render roughly 1 in 5 frames to keep the animation cheap.\n",
    "        if play and random.random() < 0.2:\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return total_reward\n",
    "\n",
    "\n",
    "test(play=False)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9.0"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 38
  },
  {
   "metadata": {
    "jupyter": {
     "is_executing": true
    },
    "ExecuteTime": {
     "start_time": "2025-09-02T02:52:12.778382Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def train(epochs=500, steps_per_epoch=200, lr=2e-3, eval_every=50):\n",
    "    \"\"\"Train the DQN: alternate data collection with batched TD updates.\n",
    "\n",
    "    epochs          -- number of collect+learn cycles (default 500).\n",
    "    steps_per_epoch -- gradient steps after each data refresh (default 200).\n",
    "    lr              -- Adam learning rate (default 2e-3).\n",
    "    eval_every      -- run a 20-episode evaluation every N epochs (default 50).\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        #Refresh the replay pool with freshly played transitions.\n",
    "        update_count, drop_count = update_data()\n",
    "\n",
    "        #Learn several times on each refreshed pool.\n",
    "        for i in range(steps_per_epoch):\n",
    "            #Sample a batch of transitions.\n",
    "            state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "            #Predicted Q-values vs. bootstrapped TD targets for the batch.\n",
    "            value = get_value(state, action)\n",
    "            target = get_target(reward, next_state, over)\n",
    "\n",
    "            #Standard TD regression step.\n",
    "            loss = loss_fn(value, target)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            #Periodically sync the delayed target network with the live one;\n",
    "            #keeping it frozen in between stabilises the TD targets.\n",
    "            if (i + 1) % 10 == 0:\n",
    "                next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "        if epoch % eval_every == 0:\n",
    "            #Average return over 20 greedy evaluation episodes.\n",
    "            test_result = sum(test(play=False) for _ in range(20)) / 20\n",
    "            print(epoch, len(datas), update_count, drop_count, test_result)\n",
    "\n",
    "\n",
    "train()"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 407 207 0 11.1\n",
      "50 10000 299 299 135.4\n",
      "100 10000 200 200 193.6\n",
      "150 10000 286 286 196.65\n",
      "200 10000 374 374 174.75\n",
      "250 10000 200 200 200.0\n",
      "300 10000 200 200 200.0\n",
      "350 10000 200 200 200.0\n",
      "400 10000 333 333 190.35\n"
     ]
    }
   ],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Watch the trained agent play one episode with inline rendering.\n",
    "test(play=True)\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
